import networkx as nx


class State:
    """A node of the automaton: the device it belongs to plus a numeric id.

    No ``__eq__``/``__hash__`` is defined, so instances compare and hash by
    identity: two ``State`` objects with the same device and id are still
    distinct dictionary keys and graph nodes.
    """

    def __init__(self, device, id):
        # ``id`` shadows the builtin inside this method only; the parameter
        # name is kept because callers may pass it by keyword.
        self.device = device
        self.id = id

    def __str__(self) -> str:
        # Compact ``device:id`` label (``%d`` requires a numeric id).
        return '%s:%d' % (self.device, self.id)

    def __repr__(self) -> str:
        # Unambiguous form for debugging, e.g. when printing ``dfa.states``.
        return '%s(%r, %r)' % (type(self).__name__, self.device, self.id)


class DFA:
    """A directed automaton over ``State`` objects, backed by a networkx graph.

    Transitions are recorded twice and kept in sync by every mutator:
    ``self.transitions`` maps ``from_state -> {to_state: symbol}``, and
    ``self.graph`` holds the same edges (without the symbol). Traversal,
    pruning and drawing operate on the graph.
    """

    def __init__(self) -> None:
        self.start = State('start', 0)
        # States created through add_state; the start state is held separately.
        self.states = []
        self.transitions = {self.start: {}}
        self.graph = nx.DiGraph()
        # Register the start state up front so graph queries rooted at it
        # (e.g. remove_unconncted) work before any transition leaves it.
        self.graph.add_node(self.start)
        # Next id handed out by add_state; 0 is used by the start state.
        self.idx = 1

    def add_state(self, device):
        """Create, register and return a new state for ``device``."""
        state = State(device, self.idx)
        self.idx += 1
        self.states.append(state)
        self.transitions[state] = {}
        self.graph.add_node(state)
        return state

    def add_transition(self, fr, to, sym):
        """Add (or overwrite) the transition ``fr -> to`` labelled ``sym``.

        Raises KeyError if ``fr`` is not a registered state.
        """
        self.transitions[fr][to] = sym
        self.graph.add_edge(fr, to)

    def remove_transition(self, fr, to):
        """Remove the transition ``fr -> to`` from both representations."""
        self.transitions[fr].pop(to)
        self.graph.remove_edge(fr, to)

    def serilize(self, node):
        """Return the devices of all states below ``node``, concatenated.

        ``node`` itself is not included; order is that of ``traverse``.
        """
        suffix = []
        self.traverse(node, suffix)
        return ''.join(suffix)

    def traverse(self, node, suffix, _path=None):
        """Append the device of every state reachable from ``node`` to ``suffix``.

        Depth-first and pre-order, with siblings visited in order of
        ``device``. A state reachable along several paths is emitted once
        per path. An edge leading back to a state already on the current
        path (a cycle) is emitted but not followed, so cyclic graphs
        terminate instead of recursing forever.

        ``_path`` is internal recursion state; callers should not pass it.
        """
        if _path is None:
            _path = set()
        _path.add(node)
        for nxt in sorted(self.graph.successors(node), key=lambda s: s.device):
            suffix.append(nxt.device)
            if nxt not in _path:
                self.traverse(nxt, suffix, _path)
        _path.discard(node)

    def remove_unconncted(self):
        """Drop every state that is not reachable from the start state.

        The graph is replaced by an independent copy of its reachable
        subgraph, which keeps isolated reachable nodes such as a lone start
        state. ``states`` and ``transitions`` are pruned to match so all
        three views stay consistent.
        """
        reachable = nx.descendants(self.graph, self.start) | {self.start}
        self.graph = self.graph.subgraph(reachable).copy()
        self.states = [s for s in self.states if s in reachable]
        # Successors of a reachable state are reachable themselves, so the
        # inner mappings need no filtering.
        self.transitions = {
            fr: outs for fr, outs in self.transitions.items() if fr in reachable
        }

    # Correctly spelled aliases; the original names remain for existing callers.
    serialize = serilize
    remove_unconnected = remove_unconncted

    def draw(self, g=None):
        """Draw ``g`` (default: this automaton's graph) with a shell layout."""
        # Imported lazily so matplotlib is only required when drawing.
        import matplotlib.pyplot as plt
        if g is None:
            g = self.graph
        nx.draw_shell(g, with_labels=True)
        plt.show()